{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": 3
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python_defaultSpec_1602139106579",
   "display_name": "Python 3.6.9 64-bit ('py36': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# HarvestText中的关键词算法benchmark"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "from tqdm import tqdm\n",
    "import jieba\n",
    "from collections import defaultdict\n",
    "from harvesttext import HarvestText"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ht = HarvestText()"
   ]
  },
  {
   "source": [
    "首先，选取的数据集是CLUE整理的CSL关键词预测数据集（[下载地址](https://github.com/CLUEbenchmark/CLUE#6-csl-%E8%AE%BA%E6%96%87%E5%85%B3%E9%94%AE%E8%AF%8D%E8%AF%86%E5%88%AB-keyword-recognition)）。需要先下载并放到本目录的`CSL关键词预测`文件夹下\n",
    "\n",
    "在上面先在开发集上做一些基本的分析及调参。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "3000"
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "source": [
    "data_dev = []\n",
    "with open('CSL关键词预测/dev.json', encoding='utf-8') as f:\n",
    "    for line in f:\n",
    "        tmp = json.loads(line)\n",
    "        data_dev.append((tmp['abst'], tmp['keyword']))\n",
    "len(data_dev)"
   ]
  },
  {
   "source": [
    "一些基础的数据探索性分析（EDA）\n",
    "- 每个文档的关键词个数\n",
    "- 关键词的长度分布\n",
    "- 考察分词`seg`的情况和不分词`nseg`的情况，有多少比例的关键词被覆盖。这决定了依赖分词和不依赖分词的算法所能达到的理论recall上限。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "all_keywords = 0\n",
    "recalls = {'seg':0, 'nseg':0}\n",
    "kwd_cnt = defaultdict(int)\n",
    "kwd_len_cnt = defaultdict(int)\n",
    "for abst, kwds in data_dev:\n",
    "    kwd_cnt[len(kwds)] += 1\n",
    "    words = set(jieba.lcut(abst))\n",
    "    all_keywords += len(kwds)\n",
    "    recalls['seg'] += len(set(kwds) & words)\n",
    "    recalls['nseg'] += sum(int(kwd in abst) for kwd in kwds)\n",
    "    for kwd in kwds:\n",
    "        kwd_len_cnt[len(kwd)] += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "defaultdict(<class 'int'>, {4: 1814, 3: 1128, 2: 58})\n"
    }
   ],
   "source": [
    "print(kwd_cnt)"
   ]
  },
  {
   "source": [
    "每篇文档的关键词数量在2-4之间"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1     0.004277\n2     0.260134\n3     0.387970\n4     0.702864\n5     0.812756\n6     0.904239\n7     0.937151\n8     0.956489\n9     0.971551\n10    0.980104\n11    0.988100\n12    0.991633\n13    0.995258\n14    0.995816\n15    0.996281\n16    0.997583\n17    0.998791\n18    0.999256\n19    0.999442\n20    0.999907\n31    1.000000\ndtype: float64"
     },
     "metadata": {},
     "execution_count": 14
    }
   ],
   "source": [
    "# 关键词长度的累积概率分布\n",
    "pd.Series(kwd_len_cnt).sort_index().cumsum() / sum(kwd_len_cnt.values())"
   ]
  },
  {
   "source": [
    "存在很长的关键词，以一个词而不是多词词组为单元的关键词算法无法处理这些情况，不过4个字以内也已经可以覆盖70%"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'seg': 0.3697471178876906, 'nseg': 0.7791000371885459}\n"
    }
   ],
   "source": [
    "for k in recalls:\n",
    "    recalls[k] /= all_keywords\n",
    "print(recalls)"
   ]
  },
  {
   "source": [
    "上述情况说明，依赖jieba分词的算法在这个数据集上最多只能达到36.97%的recall，而其他从原文直接中抽取方法（新词发现，序列标注等）有可能达到77.91%。\n",
    "\n",
    "下面的算法，因此在数值上不会有很好的表现，不过依旧可以为比较和调参提供一些参考。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "给出一个关键词抽取的示例，包括`textrank`和HarvestText封装jieba并配置好参数和停用词的`jieba_tfidf`。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "随机噪声雷达通常利用时域相关完成脉冲压缩从而进行目标检测.该文根据压缩感知理论提出一种适用于噪声雷达目标检测的新算法,它用低维投影测量和信号重建取代了传统的相关操作和压缩处理,将大量运算转移到后期处理.该算法以噪声雷达所检测的目标空间分布满足稀疏性为前提；利用发射信号形成卷积矩阵,然后通过随机抽取卷积矩阵的行构建测量矩阵；并采用迭代收缩阈值算法实现目标信号重建.该文对算法作了详细的理论推导,形成完整的实现框架.仿真实验验证了算法的有效性,并分析了对处理结果影响较大的因素.该算法能够有效地重建目标,具有良好的运算效率.与时域相关法相比,大幅度减小了目标检测误差,有效抑制了输出旁瓣,并保持了信号的相位特性.\n真实关键词：['目标', '相关', '矩阵']\njieba_tfidf 关键词(前5)：['算法', '矩阵', '检测', '目标', '信号']\ntextrank 关键词(前5)：['算法', '信号', '目标', '压缩', '矩阵']\n"
    }
   ],
   "source": [
    "text, kwds = data_dev[10]\n",
    "print(text)\n",
    "print(\"真实关键词：\", kwds)\n",
    "print(\"jieba_tfidf 关键词(前5)：\", ht.extract_keywords(text, 5, method=\"jieba_tfidf\"))\n",
    "print(\"textrank 关键词(前5)：\", ht.extract_keywords(text, 5, method=\"textrank\"))"
   ]
  },
  {
   "source": [
    "每篇文章取前5个作为预测值，我们可以得到precision@5, recall@5, F1@5来评估算法的效果"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_keywords = 0\n",
    "pred_keywords = 0\n",
    "recall_new_word = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "100%|██████████| 3000/3000 [00:29<00:00, 100.76it/s]\njieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\n"
    }
   ],
   "source": [
    "topK = 5\n",
    "ref_keywords, pred_keywords = 0, 0\n",
    "acc_count = 0\n",
    "for text, kwds in tqdm(data_dev):\n",
    "    ref_keywords += len(kwds)\n",
    "    pred_keywords += topK\n",
    "    preds = ht.extract_keywords(text, topK, method=\"jieba_tfidf\")\n",
    "    acc_count += len(set(kwds) & set(preds))\n",
    "prec = acc_count / pred_keywords\n",
    "recall = acc_count / ref_keywords\n",
    "f1 = 2*prec*recall/(prec+recall)\n",
    "print(f\"jieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "jieba Precison:0.1060, Recall:0.1478, F1:0.1235\n"
    }
   ],
   "source": [
    "print(f\"jieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")"
   ]
  },
  {
   "source": [
    "Textrank调参"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "100%|██████████| 3000/3000 [00:45<00:00, 66.11it/s]\ntextrank[block: doc, window:2, weighted:False] Precison:0.0942, Recall:0.1314, F1:0.1097\n100%|██████████| 3000/3000 [00:46<00:00, 64.20it/s]\ntextrank[block: doc, window:2, weighted:True] Precison:0.0955, Recall:0.1332, F1:0.1113\n100%|██████████| 3000/3000 [00:41<00:00, 71.53it/s]\ntextrank[block: doc, window:3, weighted:False] Precison:0.0948, Recall:0.1322, F1:0.1104\n100%|██████████| 3000/3000 [00:41<00:00, 65.70it/s]\ntextrank[block: doc, window:3, weighted:True] Precison:0.0945, Recall:0.1318, F1:0.1101\n100%|██████████| 3000/3000 [00:41<00:00, 72.11it/s]\ntextrank[block: doc, window:4, weighted:False] Precison:0.0944, Recall:0.1316, F1:0.1100\n100%|██████████| 3000/3000 [00:41<00:00, 71.65it/s]\ntextrank[block: doc, window:4, weighted:True] Precison:0.0939, Recall:0.1309, F1:0.1093\n100%|██████████| 3000/3000 [00:45<00:00, 66.37it/s]\ntextrank[block: sent, window:2, weighted:False] Precison:0.0931, Recall:0.1299, F1:0.1085\n100%|██████████| 3000/3000 [00:45<00:00, 65.93it/s]\ntextrank[block: sent, window:2, weighted:True] Precison:0.0945, Recall:0.1318, F1:0.1101\n100%|██████████| 3000/3000 [00:41<00:00, 53.28it/s]\ntextrank[block: sent, window:3, weighted:False] Precison:0.0936, Recall:0.1305, F1:0.1090\n100%|██████████| 3000/3000 [00:40<00:00, 73.21it/s]\ntextrank[block: sent, window:3, weighted:True] Precison:0.0929, Recall:0.1295, F1:0.1082\n100%|██████████| 3000/3000 [00:40<00:00, 73.50it/s]\ntextrank[block: sent, window:4, weighted:False] Precison:0.0931, Recall:0.1298, F1:0.1084\n100%|██████████| 3000/3000 [00:41<00:00, 72.45it/s]\ntextrank[block: sent, window:4, weighted:True] Precison:0.0925, Recall:0.1290, F1:0.1077\n"
    }
   ],
   "source": [
    "from itertools import product\n",
    "\n",
    "topK = 5\n",
    "block_types = [\"doc\", \"sent\"]\n",
    "window_sizes = [2, 3, 4]\n",
    "if_weighted = [False, True]\n",
    "for block_type, window, weighted in product(block_types, window_sizes, if_weighted):\n",
    "    ref_keywords, pred_keywords = 0, 0\n",
    "    acc_count = 0\n",
    "    for text, kwds in tqdm(data_dev):\n",
    "        ref_keywords += len(kwds)\n",
    "        pred_keywords += topK\n",
    "        preds = ht.extract_keywords(text, topK, method=\"textrank\", block_type=block_type, window=window, weighted=weighted)\n",
    "        acc_count += len(set(kwds) & set(preds))\n",
    "    prec = acc_count / pred_keywords\n",
    "    recall = acc_count / ref_keywords\n",
    "    f1 = 2*prec*recall/(prec+recall)\n",
    "    print(f\"textrank[block: {block_type}, window:{window}, weighted:{weighted}] Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")"
   ]
  },
  {
   "source": [
    "textrank的最佳参数是 block: doc, window:2, weighted:True\n",
    "\n",
    "precision和recall与jieba_tfidf还是有差距，可能是因为后者拥有从大量语料库中统计得到的idf数据能起到一定帮助"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "## 测试集benchmark\n",
    "\n",
    "选取各个算法的最佳参数在测试集上获得最终表现"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "3000"
     },
     "metadata": {},
     "execution_count": 22
    }
   ],
   "source": [
    "data_test = []\n",
    "with open('CSL关键词预测/test.json', encoding='utf-8') as f:\n",
    "    for line in f:\n",
    "        tmp = json.loads(line)\n",
    "        data_test.append((tmp['abst'], tmp['keyword']))\n",
    "len(data_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "100%|██████████| 3000/3000 [00:30<00:00, 99.11it/s]\njieba Precison:0.1035, Recall:0.1453, F1:0.1209\n"
    }
   ],
   "source": [
    "topK = 5\n",
    "ref_keywords, pred_keywords = 0, 0\n",
    "acc_count = 0\n",
    "for text, kwds in tqdm(data_test):\n",
    "    ref_keywords += len(kwds)\n",
    "    pred_keywords += topK\n",
    "    preds = ht.extract_keywords(text, topK, method=\"jieba_tfidf\")\n",
    "    acc_count += len(set(kwds) & set(preds))\n",
    "prec = acc_count / pred_keywords\n",
    "recall = acc_count / ref_keywords\n",
    "f1 = 2*prec*recall/(prec+recall)\n",
    "print(f\"jieba Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "100%|██████████| 3000/3000 [00:45<00:00, 65.51it/s]\ntextrank Precison:0.0955, Recall:0.1342, F1:0.1116\n"
    }
   ],
   "source": [
    "topK = 5\n",
    "ref_keywords, pred_keywords = 0, 0\n",
    "acc_count = 0\n",
    "for text, kwds in tqdm(data_test):\n",
    "    ref_keywords += len(kwds)\n",
    "    pred_keywords += topK\n",
    "    preds = ht.extract_keywords(text, topK, method=\"textrank\", block_size=\"doc\", window=2, weighted=True)\n",
    "    acc_count += len(set(kwds) & set(preds))\n",
    "prec = acc_count / pred_keywords\n",
    "recall = acc_count / ref_keywords\n",
    "f1 = 2*prec*recall/(prec+recall)\n",
    "print(f\"textrank Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")"
   ]
  },
  {
   "source": [
    "另，附上HarvestText与另一个流行的textrank的实现，[textrank4zh](https://github.com/letiantian/TextRank4ZH)的比较"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "随机噪声雷达通常利用时域相关完成脉冲压缩从而进行目标检测.该文根据压缩感知理论提出一种适用于噪声雷达目标检测的新算法,它用低维投影测量和信号重建取代了传统的相关操作和压缩处理,将大量运算转移到后期处理.该算法以噪声雷达所检测的目标空间分布满足稀疏性为前提；利用发射信号形成卷积矩阵,然后通过随机抽取卷积矩阵的行构建测量矩阵；并采用迭代收缩阈值算法实现目标信号重建.该文对算法作了详细的理论推导,形成完整的实现框架.仿真实验验证了算法的有效性,并分析了对处理结果影响较大的因素.该算法能够有效地重建目标,具有良好的运算效率.与时域相关法相比,大幅度减小了目标检测误差,有效抑制了输出旁瓣,并保持了信号的相位特性.\n真实关键词：['目标', '相关', '矩阵']\ntextrank4zh 关键词(前5)：['算法', '信号', '目标', '压缩', '运算']\n"
    }
   ],
   "source": [
    "from textrank4zh import TextRank4Keyword\n",
    "\n",
    "def textrank4zh(text, topK, window=2):\n",
    "    # same as used in ht\n",
    "    allowPOS = {'n', 'ns', 'nr', 'nt', 'nz', 'vn', 'v', 'an', 'a', 'i'}\n",
    "    tr4w = TextRank4Keyword(allow_speech_tags=allowPOS)\n",
    "    tr4w.analyze(text=text, lower=True, window=window)\n",
    "    return [item.word for item in tr4w.get_keywords(topK)]\n",
    "\n",
    "text, kwds = data_dev[10]\n",
    "print(text)\n",
    "print(\"真实关键词：\", kwds)\n",
    "print(\"textrank4zh 关键词(前5)：\", textrank4zh(text, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "100%|██████████| 3000/3000 [02:12<00:00, 24.17it/s]\ntextrank4zh Precison:0.0836, Recall:0.1174, F1:0.0977\n"
    }
   ],
   "source": [
    "topK = 5\n",
    "ref_keywords, pred_keywords = 0, 0\n",
    "acc_count = 0\n",
    "for text, kwds in tqdm(data_test):\n",
    "    ref_keywords += len(kwds)\n",
    "    pred_keywords += topK\n",
    "    preds = textrank4zh(text, topK)\n",
    "    acc_count += len(set(kwds) & set(preds))\n",
    "prec = acc_count / pred_keywords\n",
    "recall = acc_count / ref_keywords\n",
    "f1 = 2*prec*recall/(prec+recall)\n",
    "print(f\"textrank4zh Precison:{prec:.4f}, Recall:{recall:.4f}, F1:{f1:.4f}\")"
   ]
  },
  {
   "source": [
    "HarvestText的textrank的实现在精度和速度上都有一定的优势。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "总结各个算法在CSL数据及上的结果：\n",
    "\n",
    "| 算法 | P@5 | R@5 | F@5 |\n",
    "| --- | --- | --- | --- |\n",
    "| textrank4zh | 0.0836 | 0.1174 | 0.0977 |\n",
    "| ht_textrank | 0.0955 | 0.1342 | 0.1116 |\n",
    "| ht_jieba_tfidf | **0.1035** | **0.1453** | **0.1209** |\n",
    "\n",
    "综上，HarvestText的关键词抽取功能\n",
    "- 把配置好参数的jieba_tfidf作为默认方法\n",
    "- 使用自己的textrank实现而不是用流行的textrank4zh。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}